feat: Ollama::create
adrienbrault committed May 30, 2024
1 parent ce929c2 commit 9373c38
Showing 6 changed files with 163 additions and 52 deletions.
26 changes: 21 additions & 5 deletions README.md
@@ -226,17 +226,33 @@ Would be a good use case / showcase of this library/cli?

### Custom Models

#### Ollama

If you want to use an Ollama model that is not available in the enum, you can use the `Ollama::create` static method:

```php
use AdrienBrault\Instructrice\LLM\LLMConfig;
use AdrienBrault\Instructrice\LLM\Cost;
use AdrienBrault\Instructrice\LLM\OpenAiJsonStrategy;
use AdrienBrault\Instructrice\LLM\Provider\Ollama;

$instructrice->get(
...,
llm: Ollama::create(
'codestral:22b-v0.1-q5_K_M', // check its license first!
32000,
),
);
```
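
`Ollama::create` also accepts an optional label and JSON `strategy`, and when the context length is omitted it is inferred from a `<N>k` marker in the model tag (falling back to 8k). A minimal sketch using named arguments — the label here is only an illustration:

```php
use AdrienBrault\Instructrice\LLM\OpenAiJsonStrategy;
use AdrienBrault\Instructrice\LLM\Provider\Ollama;

// "128k" in the tag => a 128000 token context window is inferred.
$llm = Ollama::create(
    'phi3:14b-medium-128k-instruct-q5_K_M',
    label: 'Phi-3 Medium 128K',
    strategy: OpenAiJsonStrategy::JSON, // the default
);
```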

#### OpenAI

You can also use any OpenAI-compatible API by passing an [LLMConfig](src/LLM/LLMConfig.php):

```php
use AdrienBrault\Instructrice\InstructriceFactory;
use AdrienBrault\Instructrice\LLM\LLMConfig;
use AdrienBrault\Instructrice\LLM\Cost;
use AdrienBrault\Instructrice\LLM\OpenAiLLM;
use AdrienBrault\Instructrice\LLM\OpenAiJsonStrategy;
use AdrienBrault\Instructrice\LLM\Provider\ProviderModel;
use AdrienBrault\Instructrice\Http\GuzzleStreamingClient;
use GuzzleHttp\Client;

$instructrice->get(
...,
```
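
For reference, a self-contained sketch of such a call. The endpoint, model name, context window and labels are placeholders; the constructor argument order (endpoint URI, model, context window, label, provider label, cost, strategy) follows the `LLMConfig` instantiations in `src/LLM/Provider/Ollama.php` below.

```php
use AdrienBrault\Instructrice\LLM\Cost;
use AdrienBrault\Instructrice\LLM\LLMConfig;
use AdrienBrault\Instructrice\LLM\OpenAiJsonStrategy;

// Placeholder values for an arbitrary OpenAI-compatible chat completions API.
$config = new LLMConfig(
    'https://api.example.com/v1/chat/completions',
    'my-model',
    16000,              // context window
    'My Model',         // label displayed by the CLI/examples
    'Example Provider',
    Cost::create(0),
    strategy: OpenAiJsonStrategy::JSON,
);

$instructrice->get(
    ...,
    llm: $config,
);
```
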
22 changes: 18 additions & 4 deletions examples/bootstrap.php
@@ -2,6 +2,8 @@

declare(strict_types=1);
use AdrienBrault\Instructrice\InstructriceFactory;
use AdrienBrault\Instructrice\LLM\LLMConfig;
use AdrienBrault\Instructrice\LLM\Provider\Ollama;
use AdrienBrault\Instructrice\LLM\Provider\ProviderModel;
use Monolog\Logger;
use Psr\Log\LoggerInterface;
@@ -44,9 +46,20 @@ function createConsoleLogger(OutputInterface $output): LoggerInterface
$llmFactory = InstructriceFactory::createLLMFactory(logger: $logger);

if ($llm === null) {
$models = $llmFactory->getAvailableProviderModels();
$models[] = Ollama::create(
'codestral:22b-v0.1-q5_K_M',
32000,
);
$providerModels = reindex(
$llmFactory->getAvailableProviderModels(),
fn (ProviderModel $providerModel) => $providerModel->createConfig('123')->getLabel(),
$models,
function (ProviderModel|LLMConfig $providerModel) {
if ($providerModel instanceof LLMConfig) {
return $providerModel->getLabel();
}

return $providerModel->createConfig('123')->getLabel();
},
);

$questionSection = $output->section();
@@ -58,10 +71,11 @@ function createConsoleLogger(OutputInterface $output): LoggerInterface
));
$questionSection->clear();
$llm = $providerModels[$llmToUse];
assert($llm instanceof ProviderModel);
}

$output->writeln(sprintf('Using LLM: <info>%s</info>', $llm->createConfig('123')->getLabel()));
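// Ollama::create() already returns an LLMConfig, so only ProviderModel enum cases need createConfig() here.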
$llmConfig = $llm instanceof ProviderModel ? $llm->createConfig('123') : $llm;

$output->writeln(sprintf('Using LLM: <info>%s</info>', $llmConfig->getLabel()));
$output->writeln('');

$instructrice = InstructriceFactory::create(
2 changes: 1 addition & 1 deletion src/Instructrice.php
@@ -27,7 +27,7 @@
class Instructrice
{
public function __construct(
private readonly ProviderModel $defaultLlm,
private readonly ProviderModel|LLMConfig $defaultLlm,
private readonly LLMFactory $llmFactory,
private readonly LoggerInterface $logger,
private readonly SchemaFactory $schemaFactory,
3 changes: 2 additions & 1 deletion src/InstructriceFactory.php
@@ -8,6 +8,7 @@
use AdrienBrault\Instructrice\Http\GuzzleStreamingClient;
use AdrienBrault\Instructrice\Http\StreamingClientInterface;
use AdrienBrault\Instructrice\LLM\LLMChunk;
use AdrienBrault\Instructrice\LLM\LLMConfig;
use AdrienBrault\Instructrice\LLM\LLMFactory;
use AdrienBrault\Instructrice\LLM\Provider\Ollama;
use AdrienBrault\Instructrice\LLM\Provider\ProviderModel;
@@ -55,7 +56,7 @@ class InstructriceFactory
* @param (SerializerInterface&DenormalizerInterface)|null $serializer
*/
public static function create(
ProviderModel|null $defaultLlm = null,
ProviderModel|LLMConfig|null $defaultLlm = null,
LoggerInterface $logger = new NullLogger(),
?LLMFactory $llmFactory = null,
array $directories = [],
118 changes: 77 additions & 41 deletions src/LLM/Provider/Ollama.php
@@ -4,11 +4,14 @@

namespace AdrienBrault\Instructrice\LLM\Provider;

use AdrienBrault\Instructrice\LLM\Cost;
use AdrienBrault\Instructrice\LLM\LLMConfig;
use AdrienBrault\Instructrice\LLM\OpenAiJsonStrategy;

use function Psl\Json\encode;
use function Psl\Regex\first_match;
use function Psl\Type\int;
use function Psl\Type\shape;
use function Psl\Type\string;

enum Ollama: string implements ProviderModel
{
@@ -23,13 +26,41 @@ enum Ollama: string implements ProviderModel
case LLAMA3_8B = 'llama3:8b-instruct-';
case LLAMA3_70B = 'llama3:70b-instruct-';
case LLAMA3_70B_DOLPHIN = 'dolphin-llama3:8b-v2.9-';
case PHI3_38_128K = 'herald/phi3-128k';

public function getApiKeyEnvVar(): ?string
{
return null; // always enabled
}

/**
* If you need to customize something that is not part of the arguments, instantiate the LLMConfig yourself!
*/
public static function create(
string $model,
?int $contextLength = null,
?string $label = null,
?OpenAiJsonStrategy $strategy = OpenAiJsonStrategy::JSON,
): LLMConfig {
if ($contextLength === null) {
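// Infer the context window from a "<N>k" marker in the model tag, e.g. "128k" => 128000; fall back to 8k when there is no marker.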
$matchType = shape([
0 => string(),
'digits' => int(),
]);
$kDigits = first_match($model, '/\D(?P<digits>\d{1,3})k/', $matchType)['digits'] ?? 8;

$contextLength = $kDigits * 1000;
}

return new LLMConfig(
self::getURI(),
$model,
$contextLength,
$label ?? $model,
'Ollama',
strategy: $strategy,
);
}

public function createConfig(string $apiKey): LLMConfig
{
$strategy = match ($this) {
@@ -45,52 +76,46 @@ public function createConfig(string $apiKey): LLMConfig
$defaultVersion = match ($this) {
self::COMMANDRPLUS => 'q2_K_M',
self::STABLELM2_16 => 'q8_0',
self::LLAMA3_70B => 'q4_0',
self::PHI3_38_128K => '',
default => 'q4_K_M',
};

$ollamaHost = getenv('OLLAMA_HOST') ?: 'http://localhost:11434';
if (! str_starts_with($ollamaHost, 'http')) {
$ollamaHost = 'http://' . $ollamaHost;
}
$contextLength = match ($this) {
self::HERMES2PRO_MISTRAL_7B => 8000,
self::HERMES2PRO_LLAMA3_8B, self::HERMES2THETA_LLAMA3_8B => 8000,
self::DOLPHINCODER7 => 4000,
self::DOLPHINCODER15 => 4000,
self::STABLELM2_16 => 4000,
self::COMMANDR => 128000,
self::COMMANDRPLUS => 128000,
self::LLAMA3_8B, self::LLAMA3_70B_DOLPHIN, self::LLAMA3_70B => 8000,
};
$label = match ($this) {
self::HERMES2PRO_MISTRAL_7B => 'Nous Hermes 2 Pro Mistral 7B',
self::HERMES2PRO_LLAMA3_8B => 'Nous Hermes 2 Pro Llama3 8B',
self::HERMES2THETA_LLAMA3_8B => 'Nous Hermes 2 Theta Llama3 8B',
self::DOLPHINCODER7 => 'DolphinCoder 7B',
self::DOLPHINCODER15 => 'DolphinCoder 15B',
self::STABLELM2_16 => 'StableLM2 1.6B',
self::COMMANDR => 'CommandR 35B',
self::COMMANDRPLUS => 'CommandR+ 104B',
self::LLAMA3_8B => 'Llama3 8B',
self::LLAMA3_70B => 'Llama3 70B',
self::LLAMA3_70B_DOLPHIN => 'Llama3 8B Dolphin 2.9',
};
$stopTokens = match ($this) {
self::LLAMA3_8B, self::LLAMA3_70B => ["```\n\n", '<|im_end|>', '<|eot_id|>', "\t\n\t\n"],
default => null,
};

return new LLMConfig(
$ollamaHost . '/v1/chat/completions',
self::getURI(),
$this->value . $defaultVersion,
match ($this) {
self::HERMES2PRO_MISTRAL_7B => 8000,
self::HERMES2PRO_LLAMA3_8B, self::HERMES2THETA_LLAMA3_8B => 8000,
self::DOLPHINCODER7 => 4000,
self::DOLPHINCODER15 => 4000,
self::STABLELM2_16 => 4000,
self::COMMANDR => 128000,
self::COMMANDRPLUS => 128000,
self::PHI3_38_128K => 128000,
self::LLAMA3_8B, self::LLAMA3_70B_DOLPHIN, self::LLAMA3_70B => 8000,
},
match ($this) {
self::HERMES2PRO_MISTRAL_7B => 'Nous Hermes 2 Pro Mistral 7B',
self::HERMES2PRO_LLAMA3_8B => 'Nous Hermes 2 Pro Llama3 8B',
self::HERMES2THETA_LLAMA3_8B => 'Nous Hermes 2 Theta Llama3 8B',
self::DOLPHINCODER7 => 'DolphinCoder 7B',
self::DOLPHINCODER15 => 'DolphinCoder 15B',
self::STABLELM2_16 => 'StableLM2 1.6B',
self::COMMANDR => 'CommandR 35B',
self::COMMANDRPLUS => 'CommandR+ 104B',
self::LLAMA3_8B => 'Llama3 8B',
self::LLAMA3_70B => 'Llama3 70B',
self::LLAMA3_70B_DOLPHIN => 'Llama3 8B Dolphin 2.9',
self::PHI3_38_128K => 'Phi-3-Mini-128K',
},
$contextLength,
$label,
'Ollama',
Cost::create(0),
$strategy,
$systemPrompt,
stopTokens: match ($this) {
self::LLAMA3_8B, self::LLAMA3_70B => ["```\n\n", '<|im_end|>', '<|eot_id|>', "\t\n\t\n"],
default => null,
}
strategy: $strategy,
systemPrompt: $systemPrompt,
stopTokens: $stopTokens
);
}

@@ -124,4 +149,15 @@ private function getCommandRSystem()
PROMPT;
};
}

public static function getURI(?string $host = null): string
{
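// Honour OLLAMA_HOST (either "host:port" or a full URL), defaulting to the local daemon, and target its OpenAI-compatible chat completions endpoint.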
$host ??= getenv('OLLAMA_HOST') ?: 'http://localhost:11434';

if (! str_starts_with($host, 'http')) {
$host = 'http://' . $host;
}

return $host . '/v1/chat/completions';
}
}
44 changes: 44 additions & 0 deletions tests/LLM/Provider/OllamaTest.php
@@ -0,0 +1,44 @@
<?php

declare(strict_types=1);

namespace AdrienBrault\Instructrice\Tests\LLM\Provider;

use AdrienBrault\Instructrice\LLM\Provider\Ollama;
use PHPUnit\Framework\Attributes\CoversClass;
use PHPUnit\Framework\TestCase;

#[CoversClass(Ollama::class)]
class OllamaTest extends TestCase
{
public function testCreate128k(): void
{
$config = Ollama::create('phi3:14b-medium-128k-instruct-q5_K_M');

self::assertSame(128000, $config->contextWindow);
self::assertSame('phi3:14b-medium-128k-instruct-q5_K_M', $config->model);
}

public function testCreate4k(): void
{
$config = Ollama::create('phi3:14b-medium-4k-instruct-q5_K_M');

self::assertSame(4000, $config->contextWindow);
self::assertSame('phi3:14b-medium-4k-instruct-q5_K_M', $config->model);
}

public function testCreateWithCustomContextLength(): void
{
$config = Ollama::create('phi3:14b-medium-4k-instruct-q5_K_M', 5000);

self::assertSame(5000, $config->contextWindow);
self::assertSame('phi3:14b-medium-4k-instruct-q5_K_M', $config->model);
}

public function testCreateWithLabel(): void
{
$config = Ollama::create('model', label: 'test');

self::assertSame('test', $config->label);
}
}
